{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "This is the test code for visualizing the poisoned training and test samples on CIFAR10 and MNIST, using the dataset classes torchvision.datasets.DatasetFolder, torchvision.datasets.CIFAR10, and torchvision.datasets.MNIST.\n",
    "Attack method is Refool.\n",
    "'''\n",
    "\n",
    "import os\n",
    "import cv2\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, dataloader\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import CIFAR10, MNIST, DatasetFolder\n",
    "from torchvision.transforms import Compose, ToTensor, PILToTensor, RandomHorizontalFlip\n",
    "from torch.utils.data import DataLoader\n",
    "import core\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility settings; both values are also handed to core.Refool in the cells below.\n",
    "global_seed = 666\n",
    "deterministic = True\n",
    "torch.manual_seed(global_seed)\n",
    "# Directory holding the candidate reflection images. The original absolute path stays the\n",
    "# default; set the REFOOL_REFLECTION_DIR environment variable to point somewhere else.\n",
    "reflection_data_dir = os.environ.get(\n",
    "    \"REFOOL_REFLECTION_DIR\",\n",
    "    \"/data/ganguanhao/datasets/VOCdevkit/VOC2012/JPEGImages/\",\n",
    ")\n",
    "def read_image(img_path, type=None):\n",
    "    '''Read an image with OpenCV and optionally convert its color space.\n",
    "\n",
    "    Args:\n",
    "        img_path (str): Path of the image file handed to cv2.imread.\n",
    "        type (str | None): Requested output format. None returns the array exactly as\n",
    "            cv2.imread produced it; \"RGB\" or \"GRAY\" (case-insensitive) apply\n",
    "            cv2.COLOR_BGR2RGB or cv2.COLOR_BGR2GRAY respectively. The name shadows the\n",
    "            builtin but is kept so existing keyword callers continue to work.\n",
    "\n",
    "    Returns:\n",
    "        The loaded image array, converted as requested.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: If type is anything other than None, \"RGB\" or \"GRAY\".\n",
    "        OSError: If cv2.imread could not load an image from img_path.\n",
    "    '''\n",
    "    # Resolve the conversion first so an unsupported type fails before any file access.\n",
    "    if type is None:\n",
    "        conversion = None\n",
    "    elif isinstance(type, str) and type.upper() == \"RGB\":\n",
    "        conversion = cv2.COLOR_BGR2RGB\n",
    "    elif isinstance(type, str) and type.upper() == \"GRAY\":\n",
    "        conversion = cv2.COLOR_BGR2GRAY\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    img = cv2.imread(img_path)\n",
    "    # cv2.imread reports failure by returning None instead of raising. Surface it here with\n",
    "    # the offending path rather than as an opaque cv2.cvtColor error (or a stray None) later.\n",
    "    if img is None:\n",
    "        raise OSError(f\"cv2.imread could not read an image from {img_path!r}\")\n",
    "    return img if conversion is None else cv2.cvtColor(img, conversion)\n",
    "    \n",
    "# os.listdir returns entries in arbitrary order, so sort the names to make the selected\n",
    "# subset (the first 200 entries) the same on every machine and run.\n",
    "reflection_image_path = sorted(os.listdir(reflection_data_dir))\n",
    "reflection_images = [read_image(os.path.join(reflection_data_dir,img_path)) for img_path in reflection_image_path[:200]]\n",
    "# cv2.imread yields None for entries it cannot decode; keep only real images so the\n",
    "# reflection candidates passed to core.Refool never contain None.\n",
    "reflection_images = [img for img in reflection_images if img is not None]\n",
    "\n",
    "def show_img(x, bgr=True):\n",
    "    '''Display one image tensor with matplotlib.\n",
    "\n",
    "    Args:\n",
    "        x (torch.Tensor): Image in channel-first layout; it is permuted to channel-last\n",
    "            before plotting.\n",
    "        bgr (bool): When True (the default, matching the previous behaviour) a multi-channel\n",
    "            image is treated as BGR and converted with cv2.COLOR_BGR2RGB before display.\n",
    "            Pass False for images whose channels are already in RGB order.\n",
    "    '''\n",
    "    # detach()/cpu() are no-ops for plain CPU tensors and make .numpy() safe otherwise.\n",
    "    img = x.detach().cpu().permute(1,2,0).numpy()\n",
    "    if img.shape[2] == 1:\n",
    "        # cv2.COLOR_BGR2RGB rejects single-channel input, so grayscale images are drawn\n",
    "        # directly with a gray colormap instead of going through cv2.cvtColor.\n",
    "        plt.imshow(img[:, :, 0], cmap='gray')\n",
    "    else:\n",
    "        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if bgr else img)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DatasetFolder variant: samples are loaded with cv2.imread, so each pipeline first turns\n",
    "# the loaded array into a PIL image before applying the torchvision transforms.\n",
    "transform_train = Compose([\n",
    "    transforms.ToPILImage(),\n",
    "    # transforms.Resize((32,32)),\n",
    "    RandomHorizontalFlip(),\n",
    "    ToTensor(),\n",
    "])\n",
    "transform_test = Compose([\n",
    "    transforms.ToPILImage(),\n",
    "    # transforms.Resize((32,32)),\n",
    "    ToTensor(),\n",
    "])\n",
    "\n",
    "trainset = DatasetFolder(\n",
    "    root='/data/ganguanhao/datasets/cifar10-folder/trainset',\n",
    "    loader=cv2.imread,\n",
    "    extensions=('png',),\n",
    "    transform=transform_train,\n",
    "    target_transform=None,\n",
    "    is_valid_file=None)\n",
    "\n",
    "testset = DatasetFolder(\n",
    "    root='/data/ganguanhao/datasets/cifar10-folder/testset',\n",
    "    loader=cv2.imread,\n",
    "    extensions=('png',),\n",
    "    # The test split must use the test pipeline: transform_train would apply a random\n",
    "    # horizontal flip to evaluation samples and leave transform_test unused.\n",
    "    transform=transform_test,\n",
    "    target_transform=None,\n",
    "    is_valid_file=None)\n",
    "\n",
    "# Build the Refool attack on top of the clean datasets.\n",
    "# NOTE(review): the *_index arguments presumably select where the poisoning step is inserted\n",
    "# into each transform list; confirm against core.Refool.\n",
    "refool= core.Refool(\n",
    "    train_dataset=trainset,\n",
    "    test_dataset=testset,\n",
    "    model=core.models.ResNet(18),\n",
    "    loss=nn.CrossEntropyLoss(),\n",
    "    y_target=1,\n",
    "    poisoned_rate=0.05,\n",
    "    poisoned_transform_train_index=0,\n",
    "    poisoned_transform_test_index=0,\n",
    "    poisoned_target_transform_index=0,\n",
    "    schedule=None,\n",
    "    seed=global_seed,\n",
    "    deterministic=deterministic,\n",
    "    reflection_candidates = reflection_images,\n",
    ")\n",
    "\n",
    "\n",
    "# Show the first ten samples of each poisoned split.\n",
    "poisoned_train_dataset, poisoned_test_dataset = refool.get_poisoned_dataset()\n",
    "print(\"============train_images============\")\n",
    "for i in range(10):\n",
    "    show_img(poisoned_train_dataset[i][0])\n",
    "print(\"============test_images============\")\n",
    "for i in range(10):\n",
    "    show_img(poisoned_test_dataset[i][0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torchvision CIFAR10 variant: training samples get a random horizontal flip, test samples\n",
    "# are only converted to tensors.\n",
    "transform_train = Compose([RandomHorizontalFlip(), ToTensor()])\n",
    "transform_test = Compose([ToTensor()])\n",
    "\n",
    "trainset = CIFAR10(root='/data/ganguanhao/datasets', train=True, transform=transform_train, target_transform=None, download=True)\n",
    "testset = CIFAR10(root='/data/ganguanhao/datasets', train=False, transform=transform_test, target_transform=None, download=True)\n",
    "\n",
    "# Same attack configuration as the other cells, applied to the CIFAR10 datasets.\n",
    "refool = core.Refool(\n",
    "    train_dataset=trainset,\n",
    "    test_dataset=testset,\n",
    "    model=core.models.ResNet(18),\n",
    "    loss=nn.CrossEntropyLoss(),\n",
    "    y_target=1,\n",
    "    poisoned_rate=0.05,\n",
    "    poisoned_transform_train_index=0,\n",
    "    poisoned_transform_test_index=0,\n",
    "    poisoned_target_transform_index=0,\n",
    "    schedule=None,\n",
    "    seed=global_seed,\n",
    "    deterministic=deterministic,\n",
    "    reflection_candidates=reflection_images,\n",
    ")\n",
    "\n",
    "poisoned_train_dataset, poisoned_test_dataset = refool.get_poisoned_dataset()\n",
    "\n",
    "# Display the first ten samples of each poisoned split, training split first.\n",
    "# NOTE(review): show_img applies a BGR->RGB swap; confirm that matches the channel order of\n",
    "# the samples returned for torchvision CIFAR10.\n",
    "for banner, poisoned_split in (\n",
    "    (\"============train_images============\", poisoned_train_dataset),\n",
    "    (\"============test_images============\", poisoned_test_dataset),\n",
    "):\n",
    "    print(banner)\n",
    "    for i in range(10):\n",
    "        show_img(poisoned_split[i][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torchvision MNIST variant: training samples get a random horizontal flip, test samples\n",
    "# are only converted to tensors.\n",
    "transform_train = Compose([\n",
    "    RandomHorizontalFlip(),\n",
    "    ToTensor(),\n",
    "])\n",
    "transform_test = Compose([\n",
    "    ToTensor(),\n",
    "])\n",
    "trainset = MNIST(\n",
    "    root='/data/ganguanhao/datasets',\n",
    "    transform=transform_train,\n",
    "    target_transform=None,\n",
    "    train=True,\n",
    "    download=True)\n",
    "testset = MNIST(\n",
    "    root='/data/ganguanhao/datasets',\n",
    "    transform=transform_test,\n",
    "    target_transform=None,\n",
    "    train=False,\n",
    "    download=True)\n",
    "\n",
    "# The unused DataLoader that was built here has been dropped; nothing in this cell reads it.\n",
    "\n",
    "# Same attack configuration as the other cells, with the MNIST baseline network as the model.\n",
    "refool= core.Refool(\n",
    "    train_dataset=trainset,\n",
    "    test_dataset=testset,\n",
    "    model=core.models.BaselineMNISTNetwork(),\n",
    "    loss=nn.CrossEntropyLoss(),\n",
    "    y_target=1,\n",
    "    poisoned_rate=0.05,\n",
    "    poisoned_transform_train_index=0,\n",
    "    poisoned_transform_test_index=0,\n",
    "    poisoned_target_transform_index=0,\n",
    "    schedule=None,\n",
    "    seed=global_seed,\n",
    "    deterministic=deterministic,\n",
    "    reflection_candidates = reflection_images,\n",
    ")\n",
    "\n",
    "# Show the first ten samples of each poisoned split.\n",
    "# NOTE(review): MNIST images are grayscale; confirm show_img can handle the channel count of\n",
    "# the poisoned samples returned here.\n",
    "poisoned_train_dataset, poisoned_test_dataset = refool.get_poisoned_dataset()\n",
    "print(\"============train_images============\")\n",
    "for i in range(10):\n",
    "    show_img(poisoned_train_dataset[i][0])\n",
    "print(\"============test_images============\")\n",
    "for i in range(10):\n",
    "    show_img(poisoned_test_dataset[i][0])"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "0f3a347edc68809802ca80b72bd825e96565764d57572064874c38c4d7a65333"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 64-bit ('tc18': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
